{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Only needed on WSL: expose the CUDA toolkit binaries and the WSL GPU driver\n",
    "# libraries through the process environment.\n",
    "import os\n",
    "\n",
    "# Append the CUDA bin directory only once so re-running the cell does not keep\n",
    "# growing PATH.\n",
    "_cuda_bin_dir = \"/usr/local/cuda/bin\"\n",
    "if _cuda_bin_dir not in os.environ[\"PATH\"].split(\":\"):\n",
    "    os.environ[\"PATH\"] = f\"{os.environ['PATH']}:{_cuda_bin_dir}\"\n",
    "\n",
    "# Put the WSL/CUDA library directories first, but keep whatever was already in\n",
    "# LD_LIBRARY_PATH instead of overwriting it. Dropping the two directories from\n",
    "# the inherited value keeps this assignment idempotent when the cell is re-run.\n",
    "_wsl_lib_dirs = [\"/usr/lib/wsl/lib\", \"/usr/local/cuda/lib64\"]\n",
    "_inherited_lib_dirs = [\n",
    "    d for d in os.environ.get(\"LD_LIBRARY_PATH\", \"\").split(\":\")\n",
    "    if d and d not in _wsl_lib_dirs\n",
    "]\n",
    "os.environ[\"LD_LIBRARY_PATH\"] = \":\".join(_wsl_lib_dirs + _inherited_lib_dirs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[2023-08-04 17:43:49,433] [INFO] [real_accelerator.py:110:get_accelerator] Setting ds_accelerator to cuda (auto detect)\n",
      "\n",
      "===================================BUG REPORT===================================\n",
      "Welcome to bitsandbytes. For bug reports, please run\n",
      "\n",
      "python -m bitsandbytes\n",
      "\n",
      " and submit this information together with your error trace to: https://github.com/TimDettmers/bitsandbytes/issues\n",
      "================================================================================\n",
      "bin /home/oliverwang15/miniconda3/envs/fingpt/lib/python3.9/site-packages/bitsandbytes-0.40.0.post4-py3.9.egg/bitsandbytes/libbitsandbytes_cuda120.so\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/oliverwang15/miniconda3/envs/fingpt/lib/python3.9/site-packages/bitsandbytes-0.40.0.post4-py3.9.egg/bitsandbytes/cuda_setup/main.py:149: UserWarning: /home/oliverwang15/miniconda3/envs/fingpt did not contain ['libcudart.so', 'libcudart.so.11.0', 'libcudart.so.12.0'] as expected! Searching further paths...\n",
      "  warn(msg)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CUDA SETUP: CUDA runtime path found: /usr/local/cuda/lib64/libcudart.so\n",
      "CUDA SETUP: Highest compute capability among GPUs detected: 8.6\n",
      "CUDA SETUP: Detected CUDA version 120\n",
      "CUDA SETUP: Loading binary /home/oliverwang15/miniconda3/envs/fingpt/lib/python3.9/site-packages/bitsandbytes-0.40.0.post4-py3.9.egg/bitsandbytes/libbitsandbytes_cuda120.so...\n"
     ]
    }
   ],
   "source": [
    "from typing import List, Dict, Optional\n",
    "\n",
    "import datasets\n",
    "import torch\n",
    "from loguru import logger\n",
    "from datasets import load_dataset\n",
    "from transformers import (\n",
    "    AutoModel,\n",
    "    AutoTokenizer,\n",
    "    TrainingArguments,\n",
    "    Trainer,\n",
    "    BitsAndBytesConfig\n",
    ")\n",
    "from peft import (\n",
    "    TaskType,\n",
    "    LoraConfig,\n",
    "    get_peft_model,\n",
    "    set_peft_model_state_dict,\n",
    "    prepare_model_for_kbit_training\n",
    ")\n",
    "from peft.utils import TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hyper-parameters for the Hugging Face Trainer.\n",
    "training_args = TrainingArguments(\n",
    "    output_dir='./finetuned_model',    # saved model path\n",
    "    logging_steps=500,\n",
    "    num_train_epochs=2,\n",
    "    per_device_train_batch_size=4,\n",
    "    gradient_accumulation_steps=8,     # 4 * 8 = 32 sequences per optimizer step and device\n",
    "    learning_rate=1e-4,\n",
    "    weight_decay=0.01,\n",
    "    # The recorded run has 3838 optimizer steps in total, so the former\n",
    "    # warmup_steps=10000 never completed: the learning rate was still ramping up\n",
    "    # when training ended and never reached 1e-4. A ratio scales with the\n",
    "    # actual number of training steps.\n",
    "    warmup_ratio=0.03,\n",
    "    # NOTE(review): save_steps also exceeds the recorded 3838 total steps, so no\n",
    "    # intermediate checkpoint is written and load_best_model_at_end has nothing\n",
    "    # to restore. Left unchanged here; lower it once checkpoint loading with the\n",
    "    # custom save_model format has been verified.\n",
    "    save_steps=5000,\n",
    "    fp16=True,\n",
    "    torch_compile=False,\n",
    "    load_best_model_at_end=True,\n",
    "    evaluation_strategy=\"steps\",\n",
    "    eval_steps=500,                    # explicit; previously inherited from logging_steps\n",
    "    remove_unused_columns=False,       # the collator reads the extra \"seq_len\" column\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 4-bit NF4 quantization settings (double quantization, fp16 compute dtype)\n",
    "q_config = BitsAndBytesConfig(\n",
    "    load_in_4bit=True,\n",
    "    bnb_4bit_compute_dtype=torch.float16,\n",
    "    bnb_4bit_quant_type='nf4',\n",
    "    bnb_4bit_use_double_quant=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "You are loading your model in 8bit or 4bit but no linear modules were found in your model. Please double check your model architecture, or submit an issue on github if you think this is a bug.\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "36ff88b6b6144a909c6d1b7e41405762",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading checkpoint shards:   0%|          | 0/7 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Tokenizer and quantized base model\n",
    "model_name = \"THUDM/chatglm2-6b\"\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)\n",
    "\n",
    "# Load the checkpoint with the quantization config, then prepare it for\n",
    "# k-bit training with gradient checkpointing enabled.\n",
    "model = prepare_model_for_kbit_training(\n",
    "    AutoModel.from_pretrained(\n",
    "        model_name,\n",
    "        quantization_config=q_config,\n",
    "        trust_remote_code=True,\n",
    "        device='cuda',\n",
    "    ),\n",
    "    use_gradient_checkpointing=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Wrap the base model with LoRA adapters\n",
    "target_modules = TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING['chatglm']\n",
    "\n",
    "lora_config = LoraConfig(\n",
    "    task_type=TaskType.CAUSAL_LM,\n",
    "    target_modules=target_modules,  # module names PEFT maps to the 'chatglm' architecture\n",
    "    r=8,\n",
    "    lora_alpha=32,\n",
    "    lora_dropout=0.1,\n",
    "    bias='none',\n",
    "    inference_mode=False,\n",
    ")\n",
    "\n",
    "model = get_peft_model(model, lora_config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optionally restore adapter weights from a previous run (disabled when None).\n",
    "resume_from_checkpoint = None\n",
    "if resume_from_checkpoint is not None:\n",
    "    # Look for a full checkpoint first ...\n",
    "    checkpoint_name = os.path.join(resume_from_checkpoint, 'pytorch_model.bin')\n",
    "    if not os.path.exists(checkpoint_name):\n",
    "        # ... and fall back to the adapter-only weights file.\n",
    "        checkpoint_name = os.path.join(resume_from_checkpoint, 'adapter_model.bin')\n",
    "        resume_from_checkpoint = False\n",
    "\n",
    "    if not os.path.exists(checkpoint_name):\n",
    "        logger.info(f'Checkpoint {checkpoint_name} not found')\n",
    "    else:\n",
    "        logger.info(f'Restarting from {checkpoint_name}')\n",
    "        adapters_weights = torch.load(checkpoint_name)\n",
    "        set_peft_model_state_dict(model, adapters_weights)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "trainable params: 1,949,696 || all params: 3,390,261,248 || trainable%: 0.05750872447219737\n"
     ]
    }
   ],
   "source": [
    "model.print_trainable_parameters()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading cached split indices for dataset at /home/oliverwang15/new_model/instruct-FinGPT_v1/data/dataset_new/cache-53d0e90d10c68e06.arrow and /home/oliverwang15/new_model/instruct-FinGPT_v1/data/dataset_new/cache-55ed031d8eea1a57.arrow\n"
     ]
    }
   ],
   "source": [
    "# Read the prepared dataset from disk and hold out 20% of it as the test split\n",
    "dataset = datasets.load_from_disk(\"../data/dataset_new\").train_test_split(\n",
    "    0.2, shuffle=True, seed=42\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class ModifiedTrainer(Trainer):\n",
    "    \"\"\"Trainer variant used for the LoRA fine-tuning run.\n",
    "\n",
    "    Only ``input_ids`` and ``labels`` are fed to the model, evaluation reports\n",
    "    the loss only, and checkpoints contain just the trainable parameters.\n",
    "    \"\"\"\n",
    "\n",
    "    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):\n",
    "        \"\"\"Run a forward pass on the batch and return its loss.\n",
    "\n",
    "        Args:\n",
    "            model: model to run.\n",
    "            inputs: batch dict providing ``input_ids`` and ``labels``.\n",
    "            return_outputs: when True, return ``(loss, outputs)`` instead of\n",
    "                the bare loss. The flag used to be ignored.\n",
    "            num_items_in_batch: accepted so the override stays callable from\n",
    "                newer ``transformers`` releases that pass it; unused here.\n",
    "\n",
    "        Returns:\n",
    "            The loss, or ``(loss, outputs)`` when ``return_outputs`` is True.\n",
    "        \"\"\"\n",
    "        outputs = model(\n",
    "            input_ids=inputs[\"input_ids\"],\n",
    "            labels=inputs[\"labels\"],\n",
    "        )\n",
    "        loss = outputs.loss\n",
    "        return (loss, outputs) if return_outputs else loss\n",
    "\n",
    "    def prediction_step(self, model: torch.nn.Module, inputs, prediction_loss_only: bool, ignore_keys=None):\n",
    "        \"\"\"Compute the evaluation loss without gradients.\n",
    "\n",
    "        Returns:\n",
    "            ``(loss, None, None)``: logits and labels are never returned,\n",
    "            regardless of ``prediction_loss_only``.\n",
    "        \"\"\"\n",
    "        with torch.no_grad():\n",
    "            loss = model(\n",
    "                input_ids=inputs[\"input_ids\"].to(model.device),\n",
    "                labels=inputs[\"labels\"].to(model.device),\n",
    "            ).loss\n",
    "        return (loss, None, None)\n",
    "\n",
    "    def save_model(self, output_dir=None, _internal_call=False):\n",
    "        \"\"\"Save the training arguments and the trainable parameters.\n",
    "\n",
    "        Args:\n",
    "            output_dir: target directory; falls back to ``self.args.output_dir``\n",
    "                when omitted (previously ``None`` crashed in ``os.makedirs``).\n",
    "            _internal_call: accepted for signature compatibility; unused.\n",
    "        \"\"\"\n",
    "        # Imported locally so the class does not rely on the WSL-only cell for `os`.\n",
    "        import os\n",
    "        from transformers.trainer import TRAINING_ARGS_NAME\n",
    "\n",
    "        if output_dir is None:\n",
    "            output_dir = self.args.output_dir\n",
    "        os.makedirs(output_dir, exist_ok=True)\n",
    "        torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))\n",
    "        # Keep only parameters that receive gradients, moved to CPU.\n",
    "        # NOTE(review): keys are the raw named_parameters() names; confirm they\n",
    "        # match what set_peft_model_state_dict expects when resuming.\n",
    "        saved_params = {\n",
    "            k: v.to(\"cpu\") for k, v in self.model.named_parameters() if v.requires_grad\n",
    "        }\n",
    "        torch.save(saved_params, os.path.join(output_dir, \"adapter_model.bin\"))\n",
    "\n",
    "def data_collator(features: list, pad_token_id: Optional[int] = None, ignore_index: int = -100) -> dict:\n",
    "    \"\"\"Pad a batch to its longest sequence and build loss-masked labels.\n",
    "\n",
    "    Args:\n",
    "        features: examples, each with ``input_ids`` (a list of token ids) and\n",
    "            ``seq_len``; the first ``seq_len - 1`` tokens are excluded from the\n",
    "            loss (presumably the prompt part; verify against the dataset).\n",
    "        pad_token_id: id used to right-pad ``input_ids``; defaults to the\n",
    "            notebook-level ``tokenizer.pad_token_id``.\n",
    "        ignore_index: label value for positions that must not contribute to\n",
    "            the loss. -100 is the default ``ignore_index`` of\n",
    "            ``torch.nn.CrossEntropyLoss``. The previous code used the pad token\n",
    "            id here, so prompt and padding positions were trained to predict\n",
    "            the pad token.\n",
    "            NOTE(review): confirm the remote ChatGLM2 modeling code uses -100.\n",
    "\n",
    "    Returns:\n",
    "        Dict with ``input_ids`` and ``labels`` LongTensors of shape\n",
    "        ``(len(features), longest)``, ordered from longest to shortest example.\n",
    "    \"\"\"\n",
    "    if pad_token_id is None:\n",
    "        pad_token_id = tokenizer.pad_token_id\n",
    "    lengths = [len(feature[\"input_ids\"]) for feature in features]\n",
    "    longest = max(lengths)\n",
    "    input_ids = []\n",
    "    labels_list = []\n",
    "    # Longest first; the sort is stable, so equal-length examples keep their order.\n",
    "    for length, feature in sorted(zip(lengths, features), key=lambda x: -x[0]):\n",
    "        ids = feature[\"input_ids\"]\n",
    "        seq_len = feature[\"seq_len\"]\n",
    "        n_pad = longest - length\n",
    "        labels = [ignore_index] * (seq_len - 1) + ids[(seq_len - 1):] + [ignore_index] * n_pad\n",
    "        input_ids.append(torch.LongTensor(ids + [pad_token_id] * n_pad))\n",
    "        labels_list.append(torch.LongTensor(labels))\n",
    "    return {\n",
    "        \"input_ids\": torch.stack(input_ids),\n",
    "        \"labels\": torch.stack(labels_list),\n",
    "    }\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch.utils.tensorboard import SummaryWriter\n",
    "from transformers.integrations import TensorBoardCallback"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "You are adding a <class 'transformers.integrations.TensorBoardCallback'> to the callbacks of this Trainer, but there is already one. The currentlist of callbacks is\n",
      ":DefaultFlowCallback\n",
      "TensorBoardCallback\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/oliverwang15/miniconda3/envs/fingpt/lib/python3.9/site-packages/transformers/optimization.py:411: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning\n",
      "  warnings.warn(\n",
      "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='3838' max='3838' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [3838/3838 4:50:16, Epoch 1/2]\n",
       "    </div>\n",
       "    <table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       " <tr style=\"text-align: left;\">\n",
       "      <th>Step</th>\n",
       "      <th>Training Loss</th>\n",
       "      <th>Validation Loss</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>500</td>\n",
       "      <td>15.781000</td>\n",
       "      <td>9.021257</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1000</td>\n",
       "      <td>6.401400</td>\n",
       "      <td>6.019317</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1500</td>\n",
       "      <td>6.012100</td>\n",
       "      <td>5.969908</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2000</td>\n",
       "      <td>5.970300</td>\n",
       "      <td>5.936831</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2500</td>\n",
       "      <td>5.799500</td>\n",
       "      <td>5.645917</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3000</td>\n",
       "      <td>5.616000</td>\n",
       "      <td>5.634653</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3500</td>\n",
       "      <td>5.606800</td>\n",
       "      <td>5.628892</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table><p>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "ename": "",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[1;31mThe Kernel crashed while executing code in the the current cell or a previous cell. Please review the code in the cell(s) to identify a possible cause of the failure. Click <a href='https://aka.ms/vscodeJupyterKernelCrash'>here</a> for more info. View Jupyter <a href='command:jupyter.viewOutput'>log</a> for further details."
     ]
    }
   ],
   "source": [
    "# Train\n",
    "writer = SummaryWriter()\n",
    "trainer = ModifiedTrainer(\n",
    "    model=model, \n",
    "    args=training_args,             # Trainer args\n",
    "    train_dataset=dataset[\"train\"], # Training set\n",
    "    eval_dataset=dataset[\"test\"],   # Testing set\n",
    "    data_collator=data_collator,    # Data Collator\n",
    "    callbacks=[TensorBoardCallback(writer)],\n",
    ")\n",
    "trainer.train()\n",
    "writer.close()\n",
    "# save model\n",
    "model.save_pretrained(training_args.output_dir)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## GPU Training MEM: 40.3%\n",
    "## GPU Evaluation MEM: 73.7%"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "fingpt",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.17"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
